{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "词表大小: 715\n"
     ]
    }
   ],
   "source": [
    "from tokenizers import Tokenizer\n",
    "\n",
    "# Location of the serialized BPE tokenizer definition\n",
    "TOKENIZER_PATH = \"bpe_tokenizer.json\"\n",
    "\n",
    "# Restore the tokenizer from disk, then report how many entries its vocabulary has\n",
    "tokenizer = Tokenizer.from_file(TOKENIZER_PATH)\n",
    "vocab_size = tokenizer.get_vocab_size()\n",
    "print(\"词表大小:\", vocab_size)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "训练数据大小: torch.Size([483, 50]) torch.Size([483, 50])\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "def create_training_data(text, tokenizer, seq_length=50):\n",
    "    \"\"\"将文本转换为 Token ID，并生成训练样本\"\"\"\n",
    "    token_ids = tokenizer.encode(text).ids\n",
    "    inputs, targets = [], []\n",
    "\n",
    "    for i in range(len(token_ids) - seq_length):\n",
    "        inputs.append(token_ids[i:i+seq_length])\n",
    "        targets.append(token_ids[i+1:i+seq_length+1])\n",
    "\n",
    "    return torch.tensor(inputs), torch.tensor(targets)\n",
    "\n",
    "# Read the training corpus\n",
    "file_paths = [\"data1.txt\", \"data2.txt\"]\n",
    "\n",
    "def _read_text(path):\n",
    "    \"\"\"Return the full UTF-8 contents of ``path``, closing the file afterwards.\"\"\"\n",
    "    with open(path, \"r\", encoding=\"utf-8\") as fh:\n",
    "        return fh.read()\n",
    "\n",
    "# Join the files with a newline so the last line of one does not run into the next\n",
    "text = \"\\n\".join(_read_text(path) for path in file_paths)\n",
    "\n",
    "# Build the (input, target) training tensors\n",
    "inputs, targets = create_training_data(text, tokenizer)\n",
    "print(\"训练数据大小:\", inputs.shape, targets.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0, Loss: 5.341959476470947\n",
      "Epoch 1, Loss: 5.159451961517334\n",
      "Epoch 2, Loss: 4.773972034454346\n",
      "Epoch 3, Loss: 3.6134955883026123\n",
      "Epoch 4, Loss: 2.6023621559143066\n",
      "Epoch 5, Loss: 2.7072513103485107\n",
      "Epoch 6, Loss: 1.595573902130127\n",
      "Epoch 7, Loss: 1.4765725135803223\n",
      "Epoch 8, Loss: 1.4026365280151367\n",
      "Epoch 9, Loss: 1.3479602336883545\n",
      "Epoch 10, Loss: 1.3279881477355957\n",
      "Epoch 11, Loss: 1.2963348627090454\n",
      "Epoch 12, Loss: 1.2420926094055176\n",
      "Epoch 13, Loss: 1.196912169456482\n",
      "Epoch 14, Loss: 1.2506626844406128\n",
      "Epoch 15, Loss: 1.1052536964416504\n",
      "Epoch 16, Loss: 1.1789007186889648\n",
      "Epoch 17, Loss: 1.1432737112045288\n",
      "Epoch 18, Loss: 1.5103932619094849\n",
      "Epoch 19, Loss: 1.0925631523132324\n"
     ]
    }
   ],
   "source": [
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "\n",
    "# Pair each input row with its target row, then serve shuffled mini-batches of 16\n",
    "dataset = TensorDataset(inputs, targets)\n",
    "dataloader = DataLoader(\n",
    "    dataset,\n",
    "    batch_size=16,\n",
    "    shuffle=True,\n",
    ")\n",
    "\n",
    "# 定义 Transformer Decoder\n",
    "class TransformerDecoder(nn.Module):\n",
    "    def __init__(self, vocab_size, d_model, num_heads, num_layers, dim_feedforward, max_len):\n",
    "        \"\"\"\n",
    "        Transformer Decoder 网络架构\n",
    "        :param vocab_size: 词表大小\n",
    "        :param d_model: 词嵌入维度\n",
    "        :param num_heads: 多头注意力头数\n",
    "        :param num_layers: 解码器层数\n",
    "        :param dim_feedforward: FeedForward 层维度\n",
    "        :param max_len: 序列最大长度\n",
    "        :return: None\n",
    "\n",
    "        注意：\n",
    "        - TransformerDecoder 架构由多个 TransformerDecoderLayer 组成，每个 TransformerDecoderLayer 包含一个 MultiHeadAttention 和一个 FeedForward 层。\n",
    "        - TransformerDecoderLayer 的参数包括 d_model（词嵌入维度），num_heads（多头注意力头数），dim_feedforward（FeedForward 层维度）。\n",
    "        - TransformerDecoder 的参数包括 vocab_size（词表大小），d_model（词嵌入维度），num_heads（多头注意力头数），num_layers（解码器层数），dim_feedforward（FeedForward 层维度），max_len（序列最大长度）。\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.embedding = nn.Embedding(vocab_size, d_model)\n",
    "        self.positional_encoding = nn.Parameter(torch.randn(1, max_len, d_model))\n",
    "        \n",
    "        decoder_layer = nn.TransformerDecoderLayer(d_model, num_heads, dim_feedforward)\n",
    "        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers)\n",
    "        self.fc_out = nn.Linear(d_model, vocab_size)\n",
    "\n",
    "    def forward(self, input_seq, memory):\n",
    "        seq_len = input_seq.size(1)\n",
    "        embedded = self.embedding(input_seq) + self.positional_encoding[:, :seq_len, :]\n",
    "        memory = self.embedding(memory) + self.positional_encoding[:, :memory.size(1), :]\n",
    "        output = self.transformer_decoder(embedded, memory)\n",
    "        return self.fc_out(output)\n",
    "\n",
    "# Hyperparameters\n",
    "d_model = 64           # embedding dimension\n",
    "num_heads = 4          # attention heads per layer\n",
    "num_layers = 3         # stacked decoder layers\n",
    "dim_feedforward = 128  # hidden width of the feed-forward sub-layer\n",
    "max_len = 100          # rows in the learned positional table\n",
    "\n",
    "# Seed the global RNG so weight initialisation, dropout and batch shuffling\n",
    "# repeat on every Restart & Run All\n",
    "SEED = 42\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "model = TransformerDecoder(vocab_size, d_model, num_heads, num_layers, dim_feedforward, max_len)\n",
    "\n",
    "# Optimisation setup: Adam, with the learning rate multiplied by 0.95 every 10 scheduler steps\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
    "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.95)\n",
    "\n",
    "def train(model, dataloader, epochs=20):\n",
    "    \"\"\"Train ``model`` for next-token prediction and print the mean loss of every epoch.\n",
    "\n",
    "    Uses the notebook-level ``optimizer`` and ``scheduler`` objects.\n",
    "\n",
    "    :param model: module called as ``model(batch_inputs, batch_inputs)`` that returns\n",
    "        logits whose last dimension is the vocabulary\n",
    "    :param dataloader: iterable of ``(batch_inputs, batch_targets)`` pairs\n",
    "    :param epochs: number of passes over ``dataloader``\n",
    "    :return: None\n",
    "    \"\"\"\n",
    "    criterion = nn.CrossEntropyLoss()  # built once instead of once per batch\n",
    "    model.train()  # re-enable dropout in case the model was left in eval mode\n",
    "\n",
    "    for epoch in range(epochs):\n",
    "        running_loss, num_batches = 0.0, 0\n",
    "        for batch_inputs, batch_targets in dataloader:\n",
    "            optimizer.zero_grad()\n",
    "            output = model(batch_inputs, batch_inputs)\n",
    "            # Flatten logits and targets to one row per token; the vocabulary\n",
    "            # size is read from the logits rather than from a global.\n",
    "            loss = criterion(output.reshape(-1, output.size(-1)), batch_targets.reshape(-1))\n",
    "            loss.backward()\n",
    "            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # guard against exploding gradients\n",
    "            optimizer.step()\n",
    "            running_loss += loss.item()\n",
    "            num_batches += 1\n",
    "        scheduler.step()\n",
    "        # Report the epoch mean: the last batch alone is noisy, and would be\n",
    "        # undefined if the dataloader yielded nothing.\n",
    "        mean_loss = running_loss / num_batches if num_batches else float(\"nan\")\n",
    "        print(f\"Epoch {epoch}, Loss: {mean_loss}\")\n",
    "\n",
    "# Kick off training\n",
    "train(model=model, dataloader=dataloader, epochs=20)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "生成文本: he ll o ” remarks came\n"
     ]
    }
   ],
   "source": [
    "def generate_text(model, tokenizer, start_text, max_length=50):\n",
    "    \"\"\"用 BPE 词表生成文本\"\"\"\n",
    "    model.eval()\n",
    "    input_tokens = tokenizer.encode(\"<bos> \" + start_text).ids\n",
    "    input_tensor = torch.tensor([input_tokens])\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for _ in range(max_length):\n",
    "            output = model(input_tensor, input_tensor)\n",
    "            next_token = output.argmax(dim=-1)[:, -1].item()\n",
    "            input_tokens.append(next_token)\n",
    "            input_tensor = torch.tensor([input_tokens])\n",
    "\n",
    "            # 如果遇到 <eos> 结束标记，则停止\n",
    "            if tokenizer.decode([next_token]) == \"<eos>\":\n",
    "                break\n",
    "\n",
    "    return tokenizer.decode(input_tokens)\n",
    "\n",
    "# Smoke-test generation with a short prompt\n",
    "generated = generate_text(model, tokenizer, \"hello\", max_length=3)\n",
    "print(\"生成文本:\", generated)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
